# Copyright 2022 The GPflow Contributors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Code for building GPflow's documentation for a specified branch.
"""
import argparse
import json
import shutil
import subprocess
from datetime import datetime
from itertools import chain
from pathlib import Path
from time import perf_counter
from typing import Collection, List, Optional, Tuple, Union

from generate_module_rst import generate_module_rst
from tabulate import tabulate
from versions import Branch

import gpflow

# Checked-in documentation sources, located relative to this script.
_SRC = Path(__file__).parent
_SPHINX_SRC = _SRC.joinpath("sphinx")
_NOTEBOOKS_SRC = _SPHINX_SRC.joinpath("notebooks")

# Scratch space for the build. This is a fixed location, rather than a fresh temporary directory,
# because sharded runs (see ``--shard``) expect it to already exist and to hold earlier results.
_TMP = Path("/tmp", "gpflow_build_docs")
_BUILD_TMP = _TMP.joinpath("build")
_NOTEBOOKS_TMP = _BUILD_TMP.joinpath("notebooks")
_DOCTREE_TMP = _TMP.joinpath("doctree")


def _clean_code_source(source: List[str]) -> Tuple[List[str], int, bool]:
    """
    Strip authoring markers from the source lines of a single code cell.

    The following conventions are handled:

    - ``# hide: begin`` / ``# hide: end``: both markers, and every line between them, are dropped.
    - ``# remove-cell``: flags the whole cell for removal.
    - A line starting with ``_ = `` has that prefix removed, so ``_ = f()`` is shown as ``f()``.

    Afterwards leading and trailing blank lines are dropped, as is the newline at the end of the
    final line.

    :param source: The cell's source, as a list of lines (each usually ending in ``"\\n"``).
    :return: Tuple of ``(cleaned lines, number of hidden sections, whether to remove the cell)``.
    :raises AssertionError: If the ``# hide:`` markers are unbalanced or malformed.
    """
    kept: List[str] = []
    hidden_sections = 0
    remove_cell = False
    hiding = False

    for line in source:
        is_marker = False
        if line.startswith("#"):
            tokens = line[1:].split()
            # A bare "#" comment line has no tokens at all; treat it as ordinary source.
            keyword = tokens[0] if tokens else ""
            if keyword == "hide:":
                directive = tokens[1] if len(tokens) > 1 else ""
                if directive == "begin":
                    assert not hiding, "Missing: # hide: end"
                    hidden_sections += 1
                    hiding = True
                else:
                    assert hiding, "Missing: # hide: begin"
                    assert directive == "end"
                    hiding = False
                # The hide markers themselves are never shown.
                is_marker = True
            elif keyword == "remove-cell":
                remove_cell = True
        elif line.startswith("_ = "):
            line = line[4:]

        if not (hiding or is_marker):
            kept.append(line)

    assert not hiding, "Missing: # hide: end"

    while kept and not kept[0].strip():
        kept.pop(0)
    while kept and not kept[-1].strip():
        kept.pop()
    # Done once, after trimming: the last remaining line is non-blank, so this cannot empty it.
    if kept:
        kept[-1] = kept[-1].rstrip("\n")

    return kept, hidden_sections, remove_cell


def _post_process_ipynb(ipynb_relative_path: Path) -> None:
    """
    Post-process an executed notebook in place, mostly removing what the end-user shouldn't see.

    For every code cell this:

    - prints the cell's source and, when recorded in the cell metadata, its execution time;
    - drops ``# hide: begin`` ... ``# hide: end`` sections and ``_ = `` prefixes, and trims blank
      lines (see :func:`_clean_code_source`);
    - removes the cell if it is tagged ``# remove-cell`` or ends up empty;
    - renumbers ``execution_count`` of the remaining code cells, so numbering has no gaps.

    Non-code cells are kept unchanged.

    :param ipynb_relative_path: Path of the notebook, relative to ``_NOTEBOOKS_TMP``.
    """
    ipynb_path = _NOTEBOOKS_TMP / ipynb_relative_path

    with open(ipynb_path, "rt", encoding="utf-8") as f_read:
        notebook = json.load(f_read)

    new_cells = []
    execution_count = 0
    hide_input_count = 0
    remove_cell_count = 0

    for cell in notebook["cells"]:
        remove_cell = False

        if cell["cell_type"] == "code":
            source = cell["source"]
            # The notebook format allows the source to be a single string instead of a list.
            if isinstance(source, str):
                source = source.splitlines(keepends=True)

            print("--------------------------------------------------")
            print("".join(source))

            execution_times = cell.get("metadata", {}).get("execution", {})
            start_str = execution_times.get("iopub.execute_input")
            end_str = execution_times.get("iopub.status.idle")
            if start_str and end_str:
                # Presumably the timestamps end in a "Z" UTC designator, which
                # datetime.fromisoformat() rejects before Python 3.11.
                start = datetime.fromisoformat(start_str.rstrip("Z"))
                end = datetime.fromisoformat(end_str.rstrip("Z"))
                print("Execution time:", end - start)

            new_source, hidden_sections, remove_cell = _clean_code_source(source)
            hide_input_count += hidden_sections
            cell["source"] = new_source

            if not new_source:
                remove_cell = True

            if not remove_cell:
                # Fix execution counts that may have been distorted by 'remove-cell':
                execution_count += 1
                cell["execution_count"] = execution_count

        if remove_cell:
            print("Removing cell")
            remove_cell_count += 1
        else:
            new_cells.append(cell)

    if hide_input_count > 0:
        print(f"Removed {hide_input_count} sections tagged with `# hide`.")
    if remove_cell_count > 0:
        print(f"Removed {remove_cell_count} cells tagged with `# remove-cell`.")

    notebook["cells"] = new_cells
    with open(ipynb_path, "wt", encoding="utf-8") as f_write:
        json.dump(notebook, f_write, indent=1)


class ShardingStrategy:
    """
    Strategy for how to shard (split) the work.

    This class is used directly as the ``type=`` of an ``argparse`` option. Invalid specs are
    therefore rejected with ``ValueError``, which ``argparse`` reports as a usage error.
    """

    def __init__(self, spec: str) -> None:
        r"""
        Valid ``spec``\s are:

        - ``no``: No sharding will happen, and a single run of this script does all necessary work.
        - ``<i>/<n>``, where 0 <= i < n: Build a subset of notebooks, corresponding to job ``i`` out
          of ``n``.
        - ``collect``: Collect data generated by previous shards, and finish the work.

        :raises ValueError: If ``spec`` is not of one of the forms above.
        """
        self.spec = spec
        if spec == "no":
            # A single run does everything.
            self.setup_tmp = True
            self.build_notebooks = True
            self.build_other = True
            self.shard_i = 0
            self.shard_n = 1
        elif spec == "collect":
            # Notebooks were already built by the shard runs; only finish the remaining work.
            self.setup_tmp = False
            self.build_notebooks = False
            self.build_other = True
            self.shard_i = 0
            self.shard_n = 1
        else:
            try:
                i_str, n_str = spec.split("/")
                shard_i, shard_n = int(i_str), int(n_str)
            except ValueError as e:
                raise ValueError(
                    f"Invalid shard spec {spec!r}: expected 'no', 'collect' or '<i>/<n>'."
                ) from e
            # Only build this shard's notebooks, in a scratch directory created beforehand.
            self.setup_tmp = False
            self.build_notebooks = True
            self.build_other = False
            self.shard_i = shard_i
            self.shard_n = shard_n
        # An explicit check rather than ``assert``: it must not be stripped by ``python -O``.
        if not 0 <= self.shard_i < self.shard_n:
            raise ValueError(
                f"Invalid shard spec {spec!r}: need 0 <= i < n,"
                f" got i={self.shard_i}, n={self.shard_n}."
            )

    def __repr__(self) -> str:
        return self.spec


def _create_fake_notebook(
    destination_relative_path: Path, limit_notebooks: Collection[str]
) -> None:
    """
    Write a placeholder page for a notebook that is skipped because of ``--limit_notebooks``.

    The placeholder is reStructuredText, so it is always written with an ``.rst`` suffix, whatever
    the suffix of ``destination_relative_path``. Writing it under an ``.ipynb`` name would give a
    file that is not valid notebook JSON. Sphinx derives the same document name (the path without
    its suffix) from either file.

    :param destination_relative_path: Where the real notebook would have been written, relative
        to ``_NOTEBOOKS_TMP``.
    :param limit_notebooks: The ``--limit_notebooks`` values, quoted in the placeholder text.
    """
    limiting_command = f"--limit_notebooks {' '.join(limit_notebooks)}"
    print(f'Generating fake, due to: "{limiting_command}"')

    destination = (_NOTEBOOKS_TMP / destination_relative_path).with_suffix(".rst")
    name = destination_relative_path.name
    title = f"Fake {name}"
    # reST requires the title underline to be at least as long as the title.
    title_line = "#" * len(title)

    destination.write_text(
        f"""{title}
{title_line}

Fake {name} due to::

   {limiting_command}
""",
        encoding="utf-8",
    )


def _ipynb_relative_path(source_relative_path: Path) -> Path:
    """
    Map a notebook source path to the path of the ``.ipynb`` file built from it.

    Every suffix is stripped first, because ``.pct.py`` sources carry more than one.
    """
    stem_path = source_relative_path
    while stem_path.suffix:
        stem_path = stem_path.with_suffix("")
    return stem_path.with_suffix(".ipynb")


def _build_notebooks(
    limit_notebooks: Optional[Collection[str]], sharding: ShardingStrategy
) -> None:
    """
    Execute the notebook sources under ``_NOTEBOOKS_TMP`` that belong to this shard.

    Sources (``*.pct.py`` and ``*.md``) are visited in sorted order. Those not assigned to this
    shard are skipped. Of the rest, those excluded by ``limit_notebooks`` get a placeholder from
    ``_create_fake_notebook``. All others are executed with ``jupytext`` and post-processed.
    Finally a table of per-source wall-clock times is printed, slowest first.

    :param limit_notebooks: If not ``None``, only sources whose stem is listed here are executed.
    :param sharding: Decides which sources this run is responsible for.
    :raises subprocess.CalledProcessError: If ``jupytext`` fails.
    """
    # Building the notebooks is really slow, so each one is timed to show where optimising pays.
    banner = "XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
    pct_sources = _NOTEBOOKS_TMP.glob("**/*.pct.py")
    md_sources = _NOTEBOOKS_TMP.glob("**/*.md")
    sources = sorted(chain(pct_sources, md_sources))

    timings = []
    for index, source_path in enumerate(sources):
        started = perf_counter()
        print()
        print(banner)
        print("Building:", source_path)
        print(banner)

        source_relative_path = source_path.relative_to(_NOTEBOOKS_TMP)
        destination_relative_path = _ipynb_relative_path(source_relative_path)

        in_this_shard = index % sharding.shard_n == sharding.shard_i
        excluded = (
            limit_notebooks is not None
            and destination_relative_path.stem not in limit_notebooks
        )
        if not in_this_shard:
            print("Skipping due to sharding...")
        elif excluded:
            _create_fake_notebook(destination_relative_path, limit_notebooks)
        else:
            jupytext_command = [
                "jupytext",
                "--execute",
                "--to",
                "notebook",
                "-o",
                str(destination_relative_path),
                str(source_relative_path),
            ]
            subprocess.run(jupytext_command, cwd=_NOTEBOOKS_TMP, check=True)
            _post_process_ipynb(destination_relative_path)

        timings.append((perf_counter() - started, source_relative_path))

    slowest_first = sorted(timings, reverse=True)
    print()
    print("Notebooks by build-time:")
    print(tabulate(slowest_first, headers=["Time", "Notebook"]))
    print()
    print()


def _make_argument_parser() -> argparse.ArgumentParser:
    """
    Create the command line parser used by :func:`main`.
    """
    parser = argparse.ArgumentParser(description="Build the GPflow documentation.")
    parser.add_argument(
        "branch",
        nargs="?",
        default=None,
        type=str,
        choices=[b.value for b in Branch],
        help="Git branch that is currently being built.",
    )
    parser.add_argument(
        "destination",
        nargs="?",
        default=None,
        type=Path,
        help="Directory to write docs to.",
    )
    parser.add_argument(
        "--limit_notebooks",
        "--limit-notebooks",
        type=str,
        nargs="*",
        help=(
            "Only process the notebooks with this base/stem name; this is typically much"
            " faster and is useful when debugging.  For example, to process"
            " notebooks/tailor/external-mean-function.pct.py, use"
            " --limit_notebooks external-mean-function\n"
        ),
    )
    parser.add_argument(
        "--fail_on_warning",
        "--fail-on-warning",
        default=False,
        action="store_true",
        help="If set, crash if there were any warnings while generating documentation.",
    )
    parser.add_argument(
        "--shard",
        default=ShardingStrategy("no"),
        type=ShardingStrategy,
        help=(
            "Sharding strategy:"
            " If set to 'no' this script performs all necessary work."
            " If set to the format <i>/<n>, where 0 <= i < n then this script only computes"
            f" notebooks for shard <i> out of <n> shards. This requires that {_TMP} has manually"
            " been created, and is empty."
            " If set to 'collect' then this script assumes all notebooks already have been"
            " computed, using the <i>/<n> commands, and finishes the work."
        ),
    )
    return parser


def main() -> None:
    """
    Build the GPflow documentation, or the part of it selected by ``--shard``.

    Depending on the sharding strategy this:

    - recreates the scratch directory ``_TMP``;
    - copies the Sphinx sources into it and executes notebooks;
    - runs ``sphinx-build`` to write HTML to ``<destination>/<branch version>``.

    Missing or invalid arguments are reported through ``argparse`` (usage message, exit status 2).

    :raises subprocess.CalledProcessError: If ``sphinx-build`` fails.
    """
    parser = _make_argument_parser()
    args = parser.parse_args()
    sharding = args.shard

    # Validate the arguments needed by the final Sphinx step up front. This reports a missing
    # argument before the (slow) notebook build rather than after it. Checking the ``Branch``
    # value instead would not work: ``Branch(None)`` fails before such a check could run.
    if sharding.build_other:
        if args.branch is None:
            parser.error("'branch' command line argument missing.")
        if args.destination is None:
            parser.error("'destination' command line argument missing.")

    if sharding.setup_tmp:
        shutil.rmtree(_TMP, ignore_errors=True)
        _TMP.mkdir(parents=True)
    elif not _TMP.is_dir():
        # Shard and collect runs expect a scratch directory that was created beforehand.
        parser.error(f"{_TMP} must already exist when using --shard {sharding}.")

    # Type-ignore below is because the `dirs_exist_ok` parameter was added in Python 3.8, and we
    # still support Python 3.7. However, we only build our documentation using Python3.10+, so
    # actually this is ok.
    # pylint: disable=unexpected-keyword-arg
    shutil.copytree(_SPHINX_SRC, _BUILD_TMP, dirs_exist_ok=True)  # type: ignore[call-arg]
    # pylint: enable=unexpected-keyword-arg

    if sharding.build_notebooks:
        _build_notebooks(args.limit_notebooks, sharding)

    if sharding.build_other:
        branch = Branch(args.branch)
        version_dest = args.destination / branch.version
        shutil.rmtree(version_dest, ignore_errors=True)

        (_BUILD_TMP / "build_version.txt").write_text(branch.version)
        generate_module_rst(gpflow, _BUILD_TMP / "api")

        sphinx_commands = [
            "sphinx-build",
            "-b",
            "html",
            "-d",
            str(_DOCTREE_TMP),
            str(_BUILD_TMP),
            str(version_dest),
        ]
        if args.fail_on_warning:
            # Turn warnings into errors, but still report all of them before failing.
            sphinx_commands += ["-W", "--keep-going"]

        subprocess.run(sphinx_commands, check=True)


if __name__ == "__main__":
    main()
